import os
import time
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, average_precision_score, precision_score
import numpy as np
import sys
import pickle
import networkx as nx
from model1 import GAE

# Sizes of the label spaces used by this dataset. Meaning inferred from the
# names only (node types; event-event, event-entity and entity-entity
# relation types) -- NOTE(review): confirm against the preprocessing code.
num_type, num_event_rela_types, num_event_entity_rela_types, num_entity_rela_types = (
    91, 1, 85, 46)

def get_dataset(data_type):
    """Load one data split from its pickle file.

    Reads ``./data/<data_type>_pruned_with_bert_max_50_set_part_1.pkl``,
    which must unpickle to exactly three items
    ``[all_A_init, all_A_true, all_x_features]``.

    Args:
        data_type: split name used as the file-name prefix (e.g. "train").

    Returns:
        ``(all_A_true, all_A_init, all_x_features)`` -- note the first two are
        swapped relative to their order inside the file.

    Note:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    path = './data/' + data_type + "_pruned_with_bert_max_50_set_part_1.pkl"
    # The context manager closes the handle even if unpickling fails; the
    # previous bare open() left it to the garbage collector.
    with open(path, "rb") as f:
        [all_A_init, all_A_true, all_x_features] = pickle.load(f)
    return all_A_true, all_A_init, all_x_features




class GAE_dataset(Dataset):
    """Map-style dataset over three parallel per-graph sequences.

    Item ``i`` is the tuple ``(features[i], adj_true[i], adj_init[i])``. The
    sequences are stored as given and are assumed (not checked) to have equal
    length; ``len()`` reports the length of ``features``.
    """

    def __init__(self, features, adj_true, adj_init):
        self.features, self.adj_true, self.adj_init = features, adj_true, adj_init

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        columns = (self.features, self.adj_true, self.adj_init)
        return tuple(column[index] for column in columns)

# Load every split. get_dataset yields (adj_true, adj_init, features), while
# GAE_dataset takes (features, adj_true, adj_init); keyword arguments below
# make that reordering explicit.
train_adj_true_all, train_adj_init_all, train_features_all = get_dataset("train")
dev_adj_true_all, dev_adj_init_all, dev_features_all = get_dataset("dev")
test_adj_true_all, test_adj_init_all, test_features_all = get_dataset("test")

train_set = GAE_dataset(features=train_features_all,
                        adj_true=train_adj_true_all,
                        adj_init=train_adj_init_all)
dev_set = GAE_dataset(features=dev_features_all,
                      adj_true=dev_adj_true_all,
                      adj_init=dev_adj_init_all)
test_set = GAE_dataset(features=test_features_all,
                       adj_true=test_adj_true_all,
                       adj_init=test_adj_init_all)

# One sample per batch; only the training split is shuffled.
train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)
dev_loader = DataLoader(dataset=dev_set, batch_size=1, shuffle=False)
test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)

def get_acc(adj_pred, adj_true, epoch):
    """Score thresholded edge predictions against labels.

    Args:
        adj_pred: tensor of raw logits. It is flattened, passed through a
            sigmoid and thresholded at 0.5 to get hard predictions.
        adj_true: tensor of labels, flattened. An entry equal to 1 counts as
            positive; every other value counts as negative.
        epoch: unused; kept so existing call sites keep working.

    Returns:
        ``(tp, fn, fp, tn, auc)``.

        NOTE: slots 2 and 4 were historically named ``true_neg`` and
        ``false_neg``. Slot 2 actually counts positives predicted negative
        (false negatives) and slot 4 counts negatives predicted negative
        (true negatives). The positional order is unchanged for callers.
        ``auc`` is the ROC-AUC of the sigmoid scores, or 0.0 when
        ``roc_auc_score`` raises ValueError (e.g. only one class present).
    """
    scores = torch.sigmoid(adj_pred).view(-1).detach().cpu().numpy()
    labels = adj_true.view(-1).cpu().numpy()
    try:
        auc = roc_auc_score(labels, scores)
    except ValueError:
        auc = 0.

    # Vectorized confusion counts. This replaces per-element Python loops and
    # the previously computed-but-unused filtered accuracy, overall accuracy
    # and precision (which emitted warnings on empty or degenerate input).
    pred_pos = scores > 0.5
    is_pos = labels == 1
    tp = int(np.count_nonzero(is_pos & pred_pos))
    fn = int(np.count_nonzero(is_pos & ~pred_pos))
    fp = int(np.count_nonzero(~is_pos & pred_pos))
    tn = int(np.count_nonzero(~is_pos & ~pred_pos))
    return tp, fn, fp, tn, auc



def _edge_masks(adj_true, generator=None):
    """Pick the edge indices used for balanced scoring of one flattened graph.

    Returns ``(pos_mask, neg_mask)`` as ``torch.nonzero`` index tensors.
    ``pos_mask`` covers every non-zero label. ``neg_mask`` is a random subset,
    at most ``len(pos_mask)`` long, of the entries whose label is not 1.
    ``generator`` optionally seeds the negative sampling.
    """
    pos_mask = torch.nonzero(adj_true)
    neg_mask = torch.nonzero(adj_true - 1)
    order = torch.randperm(neg_mask.size(0), generator=generator)
    return pos_mask, neg_mask[order][:len(pos_mask)]


def _score_sampled_edges(adj_pred, adj_true, pos_mask, neg_mask, epoch):
    """Run get_acc on the positive edges plus the sampled negatives only."""
    return get_acc(torch.cat((adj_pred[pos_mask], adj_pred[neg_mask])),
                   torch.cat((adj_true[pos_mask], adj_true[neg_mask])),
                   epoch)


def _summarize(totals, num_graphs):
    """Turn summed get_acc outputs into per-graph means plus micro F1/accuracy.

    ``totals`` is ``[tp, fn, fp, tn, auc_sum]`` summed over ``num_graphs``
    scored graphs. Counts and AUC are reported as per-graph means. F1 and
    accuracy use the summed counts. Empty denominators yield 0.0 instead of
    raising ZeroDivisionError.
    """
    tp, fn, fp, tn, auc_sum = totals
    n = max(num_graphs, 1)
    f1_denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
    total = tp + fn + fp + tn
    return {
        "tp": tp / n, "tn": tn / n, "fp": fp / n, "fn": fn / n,
        "auc": auc_sum / n,
        "f1": 2 * tp / f1_denom if f1_denom else 0.,
        "acc": (tp + tn) / total if total else 0.,
    }


def _format_metrics(metrics):
    """Render a _summarize result as ' tp: .. tn: .. fp: .. fn: .. auc: .. f1: .. acc: ..'."""
    keys = ("tp", "tn", "fp", "fn", "auc", "f1", "acc")
    return "".join(f" {key}: {round(metrics[key], 4)}" for key in keys)


def _evaluate(model, loader, epoch, device, seed=0):
    """Score ``model`` on every graph in ``loader`` without updating it.

    The model runs in eval mode under ``torch.no_grad()``. Negatives are
    sampled with a generator re-seeded to ``seed`` on each call, so an
    unshuffled loader sees the same negatives every epoch and scores stay
    comparable. Graphs with an empty ``adj_init`` or no positive labels are
    skipped. The caller must switch the model back to train mode.
    """
    model.eval()
    generator = torch.Generator().manual_seed(seed)
    totals = [0, 0, 0, 0, 0.]
    num_graphs = 0
    with torch.no_grad():
        for features, adj_true, adj_init in loader:
            features = features[0].float().to(device)
            adj_init = adj_init[0].float().to(device)
            if adj_init.shape[0] == 0:
                continue
            adj_true = adj_true.to(device).view(-1)
            pos_mask, neg_mask = _edge_masks(adj_true, generator)
            if pos_mask.numel() == 0:
                continue
            adj_pred = model(features, adj_init).view(-1)
            scores = _score_sampled_edges(adj_pred, adj_true, pos_mask, neg_mask, epoch)
            totals = [running + s for running, s in zip(totals, scores)]
            num_graphs += 1
    return _summarize(totals, num_graphs)


# The device can be overridden with the GAE_DEVICE environment variable
# (e.g. GAE_DEVICE=cpu); it defaults to the previously hard-coded GPU.
device = torch.device(os.environ.get("GAE_DEVICE", "cuda:2"))
input_dim = 91  # equal to num_type

# Hyper-parameter grid. Earlier sweeps also tried hidden1 in [256, 128, 64],
# hidden2 in [128, 64, 32] and lr in [1e-4, 1e-5].
hidden1_dims = [256]
hidden2_dims = [64]
learning_rates = [1e-5]
num_epochs = 400
min_save_epoch = 30  # no checkpoint is written before this epoch
checkpoint_path = "best_model_part_1.pt"

for hidden1_dim in hidden1_dims:
    for hidden2_dim in hidden2_dims:
        for learning_rate in learning_rates:
            model = GAE(input_dim, hidden1_dim, hidden2_dim)
            model.to(device)
            optimizer = AdamW(model.parameters(), lr=learning_rate)

            best_dev_f1 = 0.
            test_f1_at_best = 0.
            for epoch in range(num_epochs):
                # _evaluate leaves the model in eval mode, so re-enter train
                # mode at the start of every epoch.
                model.train()
                all_loss = 0.
                all_train_diff = 0.
                totals = [0, 0, 0, 0, 0.]
                num_graphs = 0
                for features, adj_true, adj_init in train_loader:
                    optimizer.zero_grad()
                    features = features[0].float().to(device)
                    adj_init = adj_init[0].float().to(device)
                    if adj_init.shape[0] == 0:
                        continue

                    adj_true = adj_true.to(device).view(-1)
                    pos_mask, neg_mask = _edge_masks(adj_true)
                    if pos_mask.numel() == 0:
                        # Nothing to balance against; BCE over an empty
                        # selection is NaN and would poison the epoch loss.
                        continue
                    labels = adj_true.float()

                    adj_pred = model(features, adj_init).view(-1)
                    # Balanced loss: mean BCE over positives plus mean BCE
                    # over an equal-size random sample of negatives.
                    loss = F.binary_cross_entropy_with_logits(adj_pred[pos_mask], labels[pos_mask])
                    if neg_mask.numel() > 0:
                        loss = loss + F.binary_cross_entropy_with_logits(
                            adj_pred[neg_mask], labels[neg_mask])
                    loss.backward()
                    optimizer.step()
                    all_loss += loss.item()
                    # |#true edges - #predicted edges| over the whole graph.
                    all_train_diff += abs(adj_true.sum().item()
                                          - (torch.sigmoid(adj_pred.detach()) > 0.5).sum().item())

                    scores = _score_sampled_edges(adj_pred, adj_true, pos_mask, neg_mask, epoch)
                    totals = [running + s for running, s in zip(totals, scores)]
                    num_graphs += 1

                train_metrics = _summarize(totals, num_graphs)
                print(f"epoch {epoch} loss: {all_loss} diff: {round(all_train_diff, 4)} train"
                      + _format_metrics(train_metrics))

                # Select the checkpoint on the dev split; test is only
                # reported, so its numbers stay an unbiased estimate.
                dev_metrics = _evaluate(model, dev_loader, epoch, device)
                test_metrics = _evaluate(model, test_loader, epoch, device)
                if dev_metrics["f1"] >= best_dev_f1 and epoch >= min_save_epoch:
                    best_dev_f1 = dev_metrics["f1"]
                    test_f1_at_best = test_metrics["f1"]
                    torch.save(model.state_dict(), checkpoint_path)
                print(f"epoch {epoch} dev" + _format_metrics(dev_metrics))
                print(f"epoch {epoch} test" + _format_metrics(test_metrics))
            print(f"hidden1_dim: {hidden1_dim} hidden2_dim: {hidden2_dim} lr: {learning_rate}"
                  f" best dev f1: {round(best_dev_f1, 4)}"
                  f" test f1 at best dev: {round(test_f1_at_best, 4)}")




    
    








































